Avoid unpad/pad repeated calls when use_cache=False #5

Open · wants to merge 1 commit into base: add-flash-attn-2
Conversation

fxmarty (Collaborator) commented on Sep 12, 2023

As per title. The difference is quite large. This is only done out of curiosity, cc @younesbelkada
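As a rough illustration of the idea, here is a minimal sketch (not the actual diff): the base PR unpads and re-pads the hidden states around flash attention inside every decoder layer, while this PoC unpads once before the layer stack and pads back once at the end. The flash-attn helpers `unpad_input` / `pad_input` are real, but the `decoder_layers` loop, the layer call signature, and the shapes are simplified assumptions.

```python
# Minimal sketch of the idea, not the actual PR diff (shapes and the layer call
# signature are simplified; unpad_input's return signature can vary across
# flash-attn versions).
from flash_attn.bert_padding import pad_input, unpad_input

def forward_per_layer_unpad(hidden_states, attention_mask, decoder_layers):
    # Base PR: conceptually, every attention layer unpads before flash attention
    # and pads back afterwards, so the (un)padding cost is paid num_layers times.
    batch, seqlen, _ = hidden_states.shape
    for layer in decoder_layers:
        x, indices, cu_seqlens, max_seqlen = unpad_input(hidden_states, attention_mask)
        x = layer(x, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
        hidden_states = pad_input(x, indices, batch, seqlen)
    return hidden_states

def forward_unpad_once(hidden_states, attention_mask, decoder_layers):
    # This PoC: unpad once at the model boundary, run every layer on the packed
    # (total_tokens, hidden_size) tensor, and pad back a single time at the end.
    batch, seqlen, _ = hidden_states.shape
    x, indices, cu_seqlens, max_seqlen = unpad_input(hidden_states, attention_mask)
    for layer in decoder_layers:
        x = layer(x, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
    return pad_input(x, indices, batch, seqlen)
```

Inside each layer, attention would then call flash_attn_varlen_func directly on the packed tokens with the precomputed cu_seqlens / max_seqlen.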

Note: speedup over the base PR is expected only for batch_size > 1 and when padding / masked tokens are present. In the benchmark below, we use a padding percentage of 30%.

This is on a single A100 for meta-llama/Llama-2-7b-hf.

Forward only with no_grad mode

batch_size=4, len=1000

| Transformers latency (ms) | Younes PR latency (ms) | This PoC latency (ms) | Younes speedup / transformers | This PoC speedup / transformers |
|---|---|---|---|---|
| 395 | 321 | 237 | 1.228 | 1.665 |

batch_size=4, len=2000

| Transformers latency (ms) | Younes PR latency (ms) | This PoC latency (ms) | Younes speedup / transformers | This PoC speedup / transformers |
|---|---|---|---|---|
| OOM | 627 | 475 | / | / |

batch_size=8, len=500

| Transformers latency (ms) | Younes PR latency (ms) | This PoC latency (ms) | Younes speedup / transformers | This PoC speedup / transformers |
|---|---|---|---|---|
| 357 | 318 | 231 | 1.120 | 1.545 |

batch_size=2, len=4000

| Transformers latency (ms) | Younes PR latency (ms) | This PoC latency (ms) | Younes speedup / transformers | This PoC speedup / transformers |
|---|---|---|---|---|
| OOM | 327 | 262 | / | / |

Forward + backward

batch_size=4, len=1500

| Transformers latency (ms) | Younes PR latency (ms) | This PoC latency (ms) | Younes speedup / transformers | This PoC speedup / transformers |
|---|---|---|---|---|
| OOM | 1353 | 1062 | / | / |

batch_size=2, len=3000

| Transformers latency (ms) | Younes PR latency (ms) | This PoC latency (ms) | Younes speedup / transformers | This PoC speedup / transformers |
|---|---|---|---|---|
| OOM | 1422 | 1178 | / | / |

batch_size=2, len=1000

| Transformers latency (ms) | Younes PR latency (ms) | This PoC latency (ms) | Younes speedup / transformers | This PoC speedup / transformers |
|---|---|---|---|---|
| 580 | 506 | 423 | 1.146 | 1.370 |

fxmarty (Collaborator, Author) commented on Sep 12, 2023

Benchmark script, for completeness: https://pastebin.com/zWE9Aedr
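The linked pastebin has the actual script; the sketch below only illustrates the general shape of such a latency measurement (CUDA-event timing of no_grad forwards with warmup). The `model`, `input_ids` / `attention_mask` and run counts are placeholders.

```python
# Rough sketch of the kind of timing loop used, NOT the linked pastebin script.
# `model`, `input_ids` and `attention_mask` are placeholders.
import torch

def measure_latency_ms(model, input_ids, attention_mask, num_runs=10, warmup=3):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    with torch.no_grad():
        for _ in range(warmup):  # warm up kernels / allocator before timing
            model(input_ids=input_ids, attention_mask=attention_mask)
        torch.cuda.synchronize()
        start.record()
        for _ in range(num_runs):
            model(input_ids=input_ids, attention_mask=attention_mask)
        end.record()
        torch.cuda.synchronize()
    return start.elapsed_time(end) / num_runs  # average forward latency in ms
```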

younesbelkada (Owner) left a comment

The changes overall look great to me! I wonder if we can add padding_mask inside flash_kwargs. To me this makes the attention forward signature a bit more complicated, but for the speedup we get I think it is worth it. Can you also confirm that generate with use_cache works fine here?
I would also like a review from @ArthurZucker before merging this.

fxmarty (Collaborator, Author) commented on Sep 14, 2023

This is a draft, by the way; I just wanted to get some results. I'm not sure it is a great fit for transformers though, with the modifications going directly into the LlamaModel forward.

generate with use_cache=True unfortunately cannot use this path: the KV cache implementation in transformers gives the keys and values a dynamic shape in the attention, so the cumulative and maximum sequence lengths change at every step and differ from the length of hidden_states, which is e.g. simply 1 in the decoding phase.
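To make this concrete, here is a toy illustration (made-up numbers, not code from the PR): during cached decoding each sequence contributes a single query token while the keys/values span the whole past, so the cumulative sequence lengths for q and for k/v diverge and change at every generation step, which defeats an unpadding computed once at the model boundary.

```python
# Toy illustration (made-up numbers, not code from the PR): with a KV cache,
# each sequence contributes one query token per decoding step while keys/values
# span the whole past, so cu_seqlens for q and for k/v differ and change every step.
import torch

batch_size, past_len = 2, 10
q_lens = torch.tensor([1] * batch_size)               # one new token per sequence
kv_lens = torch.tensor([past_len + 1] * batch_size)   # past tokens + current token

cu_seqlens_q = torch.nn.functional.pad(q_lens.cumsum(0), (1, 0))   # tensor([0, 1, 2])
cu_seqlens_k = torch.nn.functional.pad(kv_lens.cumsum(0), (1, 0))  # tensor([0, 11, 22])
# These would have to be recomputed (and the cache re-unpadded) at every step,
# which a single unpad/pad at the model boundary cannot provide.
```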

ArthurZucker (Collaborator) left a comment

It's nice that you found a way to do this, but as you said it is not very transformers-like and a bit bloated, especially if it is unusable with use_cache=True 😢

younesbelkada pushed a commit that referenced this pull request Mar 14, 2024
younesbelkada added a commit that referenced this pull request Mar 18, 2024